import numpy as np
import time
# from keras.datasets import mnist  # 仅用于数据加载，不用 Keras 训练

# 1. 定义激活函数
def relu(x):
    return np.maximum(0, x)

def relu_derivative(x):
    return (x > 0).astype(float)

def softmax(x):
    exps = np.exp(x - np.max(x))  # 防止溢出
    return exps / np.sum(exps, axis=0)

# 2. 卷积运算
def convolve2d(image, kernel, stride=1, padding=0):
    image_padded = np.pad(image, ((padding, padding), (padding, padding)), mode='constant', constant_values=0)
    kernel_size = kernel.shape[0]
    output_size = (image.shape[0] - kernel_size + 2 * padding) // stride + 1
    result = np.zeros((output_size, output_size))
    
    for i in range(output_size):
        for j in range(output_size):
            region = image_padded[i * stride:i * stride + kernel_size, j * stride:j * stride + kernel_size]
            result[i, j] = np.sum(region * kernel)
    
    return result

# 3. 最大池化
def max_pooling(image, pool_size=2, stride=2):
    output_size = image.shape[0] // pool_size
    result = np.zeros((output_size, output_size))
    
    for i in range(output_size):
        for j in range(output_size):
            region = image[i * stride:i * stride + pool_size, j * stride:j * stride + pool_size]
            result[i, j] = np.max(region)
    
    return result

# 4. 初始化 CNN 参数
np.random.seed(42)
conv_filter = np.random.randn(3, 3) * 0.1  # 3x3 卷积核
fc_weights = np.random.randn(13*13, 10) * 0.1  # 全连接层权重（13*13 因为 28x28 -> 26x26 卷积 -> 13x13 池化）
# 偏置的作用: 增加模型的灵活性, 允许激活函数的输出进行整体偏移, 相当于为每个类别添加一个可学习的阈值
# 为什么是10个偏置? 因为有10个类别, 每个类别都有一个偏置
# 为什么初始化为0? 因为偏置的初始值对模型的训练影响很大, 如果初始化为0, 那么所有的神经元都没有激活的机会
# 偏置初始化为0不会导致神经元对称性问题, 在训练过程中，偏置值会通过反向传播逐渐调整到合适的值
fc_bias = np.zeros(10)

# 5. 前向传播
def forward_pass(image):
    conv_out = convolve2d(image, conv_filter)
    relu_out = relu(conv_out)
    pooled_out = max_pooling(relu_out)
    flattened = pooled_out.flatten()
    fc_out = np.dot(flattened, fc_weights) + fc_bias
    output = softmax(fc_out)
    return conv_out, relu_out, pooled_out, flattened, fc_out, output

# 6. 反向传播
def backward_pass(image, label, lr=0.01):
    global conv_filter, fc_weights, fc_bias
    conv_out, relu_out, pooled_out, flattened, fc_out, output = forward_pass(image)
    
    # 计算误差
    one_hot_label = np.zeros(10)
    one_hot_label[label] = 1
    error = output - one_hot_label
    
    # 计算全连接层梯度
    d_fc_weights = np.outer(flattened, error)
    d_fc_bias = error
    
    # 计算池化层梯度
    d_flattened = np.dot(error, fc_weights.T)
    d_pooled = d_flattened.reshape(pooled_out.shape)
    
    # 计算 ReLU 反向传播
    d_relu = np.zeros_like(relu_out)
    for i in range(pooled_out.shape[0]):
        for j in range(pooled_out.shape[1]):
            patch = relu_out[i*2:i*2+2, j*2:j*2+2]
            max_val = pooled_out[i, j]
            d_relu[i*2:i*2+2, j*2:j*2+2] = (patch == max_val) * d_pooled[i, j]
    
    # 计算卷积层梯度
    d_conv_filter = np.zeros_like(conv_filter)
    for i in range(d_relu.shape[0] - 2):
        for j in range(d_relu.shape[1] - 2):
            d_conv_filter += d_relu[i:i+3, j:j+3] * image[i, j]
    
    # 参数更新
    conv_filter -= lr * d_conv_filter
    fc_weights -= lr * d_fc_weights
    fc_bias -= lr * d_fc_bias

# 7. 训练 CNN
# MNIST 数据集以 CSV 格式 存储时，通常包含 每张手写数字图像的像素值，以及对应的标签（label）。
# 每行表示一张 28×28（784 像素） 的灰度图片.
# 第一列 label：表示手写数字的真实类别（0~9）。
# 后面的 784 列 pixel0~pixel783：表示手写数字图像的像素值。每个像素值是一个 0~255 的整数，表示该像素的灰度级别。

# (x_train, y_train), (x_test, y_test) = mnist.load_data()
# x_train = x_train[:1000] / 255.0  # 仅取 1000 张图片训练，加快计算
# y_train = y_train[:1000]

# 7. 训练 CNN
# 从CSV文件加载数据
# train_data = np.loadtxt('mnist_train_100.csv', delimiter=',')
# test_data = np.loadtxt('mnist_test_10.csv', delimiter=',')

# # 分离特征和标签
# x_train = train_data[:1000, :-1].reshape(-1, 28, 28) / 255.0  # 仅取1000张图片训练
# y_train = train_data[:1000, -1].astype(int)

# x_test = test_data[:, :-1].reshape(-1, 28, 28) / 255.0
# y_test = test_data[:, -1].astype(int)

start_time = time.time()
# 在这行代码中，除以255是进行数据归一化（normalization）处理，原因如下
# 1. MNIST数据集中的图像是灰度图像, 每个像素值的范围是0-255（8位无符号整数）,0表示黑色，255表示白色
# 2. 归一化的好处： 将所有像素值缩放到[0,1]区间，有助于神经网络的训练稳定性， 防止在计算梯度时出现数值过大或过小的问题
# 3. 除以255后，原来的0仍然是0（0/255 = 0）， 原来的255变成1（255/255 = 1）， 其他值都被等比例缩放到0-1之间
# 训练图像数据, 每张图像的像素大小(28×28)
x_train = np.load('x_train.npy')[:2000] / 255.0  # 仅取1000张图片训练
# 训练标签数据, 包含0-9的数字标签, 对应每张训练图像的真实数字
y_train = np.load('y_train.npy')[:2000]
train_load_time = time.time() - start_time
print(f"训练数据加载时间: {train_load_time:.2f}秒")
print(f"训练数据集数量: {len(x_train)}张图片")

start_time = time.time()
x_test = np.load('x_test.npy')[:100] # 仅取100张图片测试
y_test = np.load('y_test.npy')[:100]
test_load_time = time.time() - start_time
print(f"测试数据加载时间: {test_load_time:.2f}秒")
print(f"测试数据集数量: {len(x_test)}张图片")

for epoch in range(5):  # 总共训练5轮
    loss = 0           # 初始化当前轮次的总损失
    correct = 0        # 初始化当前轮次预测正确的样本数
    
    # 遍历训练集中的每张图片
    for i in range(len(x_train)):
        image, label = x_train[i], y_train[i]  # 获取当前图片和其标签
        
        # 前向传播，获取预测结果
        _, _, _, _, _, output = forward_pass(image)
        # 反向传播，更新模型参数
        backward_pass(image, label)
        
        # 计算交叉熵损失
        loss += -np.log(output[label])
        # 统计预测正确的样本数
        correct += (np.argmax(output) == label)
    
    # 打印当前轮次的平均损失和准确率
    print(f"当前轮次 {epoch+1}: 平均损失={loss/len(x_train):.4f}, 准确率={correct/len(x_train):.4f}")

# 8. 测试 CNN
def predict(image):
    _, _, _, _, _, output = forward_pass(image)
    return np.argmax(output)

correct = 0
for i in range(len(x_test)):
    if predict(x_test[i]) == y_test[i]:
        correct += 1

print(f"Test Accuracy: {correct / len(x_test):.2f}")

